{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "import os\n",
    "\n",
    "# All locations are resolved against the parent of the current working directory.\n",
    "_BASE_DIR = Path(os.path.abspath(\"\")).parent\n",
    "\n",
    "# Output folder for rendered figures; created up front if it does not exist.\n",
    "SAVE_PATH = _BASE_DIR.joinpath(\"_static\")\n",
    "SAVE_PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Pickled results are read from results/<run name>/results.pickle.\n",
    "RESULTS_PATH = _BASE_DIR.joinpath(\"results\")\n",
    "FEDAVG_PATH = RESULTS_PATH.joinpath(\"fedavg\", \"results.pickle\")\n",
    "STATAVG_PATH = RESULTS_PATH.joinpath(\"statavg\", \"results.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_accuracy(stat_avg_path: Path, fedavg_path: Path) -> plt.Figure:\n",
    "    \"\"\"Plot the per-round accuracy of the StatAvg and FedAvg runs on one axes.\n",
    "\n",
    "    Each pickle is expected to hold a mapping whose ``\"history\"`` entry exposes\n",
    "    ``metrics_distributed[\"accuracy\"]`` as a sequence of ``(round, accuracy)``\n",
    "    pairs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    stat_avg_path : Path\n",
    "        Pickle file with the StatAvg results.\n",
    "    fedavg_path : Path\n",
    "        Pickle file with the FedAvg results.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matplotlib.figure.Figure\n",
    "        The figure holding both accuracy curves, so the caller can save it.\n",
    "    \"\"\"\n",
    "    # NOTE: pickle.load can execute arbitrary code; only open trusted result files.\n",
    "    with open(stat_avg_path, \"rb\") as file:\n",
    "        statavg_results = pickle.load(file)\n",
    "\n",
    "    with open(fedavg_path, \"rb\") as file:\n",
    "        fedavg_results = pickle.load(file)\n",
    "\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(12, 8))\n",
    "    for results, label in [(statavg_results, \"StatAvg\"), (fedavg_results, \"FedAvg\")]:\n",
    "        accuracy_lst = results[\"history\"].metrics_distributed[\"accuracy\"]\n",
    "\n",
    "        # Split the (round, accuracy) pairs into x and y series.\n",
    "        rounds = [p[0] for p in accuracy_lst]\n",
    "        acc = [p[1] for p in accuracy_lst]\n",
    "\n",
    "        ax.plot(rounds, acc, marker=\"o\", linestyle=\"-\", label=label)\n",
    "    ax.legend(fontsize=14)\n",
    "    ax.set_xlabel(\"Rounds\", fontsize=14)\n",
    "    ax.set_ylabel(\"Testing Accuracy\", fontsize=14)\n",
    "    ax.tick_params(axis=\"both\", labelsize=14)\n",
    "\n",
    "    ax.grid(True)\n",
    "    fig.show()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def saveFig(name, fig):\n",
    "    \"\"\"Write ``fig`` to ``name`` as a PNG, cropped tightly with 0.2 in padding.\"\"\"\n",
    "    # Keyword options forwarded unchanged to Figure.savefig.\n",
    "    save_options = {\n",
    "        \"dpi\": None,\n",
    "        \"facecolor\": fig.get_facecolor(),\n",
    "        \"edgecolor\": \"none\",\n",
    "        \"orientation\": \"portrait\",\n",
    "        \"format\": \"png\",\n",
    "        \"transparent\": False,\n",
    "        \"bbox_inches\": \"tight\",\n",
    "        \"pad_inches\": 0.2,\n",
    "        \"metadata\": None,\n",
    "    }\n",
    "    fig.savefig(name, **save_options)\n",
    "\n",
    "\n",
    "# Build the accuracy comparison figure, then store it in the save directory.\n",
    "fig = plot_accuracy(stat_avg_path=STATAVG_PATH, fedavg_path=FEDAVG_PATH)\n",
    "saveFig(name=SAVE_PATH / \"results.png\", fig=fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "statavg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
